{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "import torch\n",
    "from time import time\n",
    "import os\n",
    "import datetime\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "from torch.utils import data\n",
    "from image_utils import dense_warp, warp\n",
    "device = 'cuda'\n",
    "height,width = 360,640\n",
    "batch_size = 1\n",
    "grid_h,grid_w = 15,15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_warp(net_out,img):\n",
    "    '''\n",
    "    Warp `img` with the mesh displacement predicted by the network.\n",
    "\n",
    "    Inputs:\n",
    "        net_out: per-vertex (x, y) offsets added to an identity mesh spanning [-1, 1],\n",
    "                 torch.Size([batch, grid_h + 1, grid_w + 1, 2])\n",
    "        img: image batch to warp, torch.Size([batch, C, H, W]), on the same device as net_out\n",
    "    Returns:\n",
    "        warped image batch of spatial size (height, width); samples that fall outside\n",
    "        the source image are filled with zeros.\n",
    "    '''\n",
    "    # Identity mesh: rows follow y (grid_h + 1 vertices), columns follow x (grid_w + 1\n",
    "    # vertices). Each axis must use its own vertex count so non-square meshes work.\n",
    "    grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, grid_h + 1, device=net_out.device),\n",
    "                                    torch.linspace(-1, 1, grid_w + 1, device=net_out.device),\n",
    "                                    indexing='ij')\n",
    "    # grid_sample expects (x, y) order in the last dim; the leading singleton dim\n",
    "    # broadcasts against whatever batch size net_out has.\n",
    "    src_grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)\n",
    "    new_grid = src_grid + net_out\n",
    "    # Upsample the coarse mesh to one sampling coordinate per output pixel.\n",
    "    grid_upscaled = F.interpolate(new_grid.permute(0,3,1,2), size=(height,width), mode='bilinear', align_corners=True)\n",
    "    # NOTE(review): align_corners=False here differs from the interpolate call above; kept\n",
    "    # unchanged because existing checkpoints were presumably trained with this setting.\n",
    "    warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1), align_corners=False, padding_mode='zeros')\n",
    "    return warped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StabNet(nn.Module):\n",
    "    '''\n",
    "    VGG19-based regressor that predicts a mesh of 2-D vertex offsets.\n",
    "\n",
    "    Input:  torch.Size([batch, 9, H, W]) (the first conv layer is rebuilt for 9 channels)\n",
    "    Output: torch.Size([batch, grid_h + 1, grid_w + 1, 2])\n",
    "    '''\n",
    "    def __init__(self,trainable_layers = 10):\n",
    "        '''\n",
    "        Inputs:\n",
    "            trainable_layers: number of trailing VGG19 parameter tensors left trainable.\n",
    "                              Weights and biases count separately, and the count runs over\n",
    "                              the whole VGG19 (including its classifier, which is not part\n",
    "                              of self.encoder).\n",
    "        '''\n",
    "        super().__init__()\n",
    "        # Load the ImageNet pre-trained VGG19 model\n",
    "        vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n",
    "        # Extract conv1 pretrained weights for RGB input\n",
    "        rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n",
    "        # Initialise the 6 extra input channels with the channel-averaged RGB kernel\n",
    "        average_rgb_weights = torch.mean(rgb_weights, dim=1, keepdim=True).repeat(1,6,1,1)  #torch.Size([64, 6, 3, 3])\n",
    "        # Change size of the first layer from 3 to 9 channels.\n",
    "        # bias=False means the pre-trained first-layer bias is not carried over; kept as-is\n",
    "        # so the state_dict keys stay compatible with existing checkpoints.\n",
    "        vgg19.features[0] = nn.Conv2d(9,64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        # set new weights\n",
    "        new_weights = torch.cat((rgb_weights, average_rgb_weights), dim=1)\n",
    "        vgg19.features[0].weight = nn.Parameter(new_weights)\n",
    "        # Freeze every parameter tensor except the last `trainable_layers` ones.\n",
    "        # `>=` (not `>`) so that exactly `trainable_layers` tensors stay trainable.\n",
    "        all_params = list(vgg19.parameters())\n",
    "        first_trainable = len(all_params) - trainable_layers\n",
    "        for idx, param in enumerate(all_params):\n",
    "            param.requires_grad = idx >= first_trainable\n",
    "        # Keep the VGG19 feature extractor without its last module\n",
    "        self.encoder = nn.Sequential(*list(vgg19.features)[:-1])\n",
    "        self.regressor = nn.Sequential(nn.Linear(512,2048),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(2048,1024),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(1024,512),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n",
    "        total_encoder_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n",
    "        total_regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n",
    "        print(\"Total Trainable encoder Parameters: \", total_encoder_params)\n",
    "        print(\"Total Trainable regressor Parameters: \", total_regressor_params)\n",
    "        print(\"Total Trainable parameters:\",total_regressor_params + total_encoder_params)\n",
    "    \n",
    "    def forward(self, x_tensor):\n",
    "        '''\n",
    "        Inputs:\n",
    "            x_tensor: torch.Size([batch, 9, H, W])\n",
    "        Returns:\n",
    "            mesh offsets, torch.Size([batch, grid_h + 1, grid_w + 1, 2])\n",
    "        '''\n",
    "        x_batch_size = x_tensor.size(0)\n",
    "        features = self.encoder(x_tensor)\n",
    "        # Global average pooling over the spatial dims -> (batch, 512)\n",
    "        pooled = torch.mean(features, dim=[2, 3])\n",
    "        offsets = self.regressor(pooled)\n",
    "        return offsets.view(x_batch_size,grid_h + 1,grid_w + 1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Trainable encoder Parameters:  2360320\n",
      "Total Trainable regressor Parameters:  3936256\n",
      "Total Trainable parameters: 6296576\n",
      "loaded weights ./ckpts/original/stabnet_2023-10-26_13-42-14.pth\n"
     ]
    }
   ],
   "source": [
    "ckpt_dir = './ckpts/original/'\n",
    "stabnet = StabNet().to(device).eval()\n",
    "\n",
    "def _ckpt_timestamp(fname):\n",
    "    '''Parse '<prefix>_<YYYY-MM-DD>_<HH-MM-SS>.pth' into a datetime (date AND time).'''\n",
    "    _, date_str, time_str = os.path.splitext(fname)[0].split('_')[:3]\n",
    "    return datetime.datetime.strptime(f'{date_str}_{time_str}', '%Y-%m-%d_%H-%M-%S')\n",
    "\n",
    "# Only consider checkpoint files; stray files in the folder would break the timestamp parsing.\n",
    "ckpts = [f for f in os.listdir(ckpt_dir) if f.endswith('.pth')] if os.path.isdir(ckpt_dir) else []\n",
    "if ckpts:\n",
    "    # Get the filename of the latest checkpoint. Sorting on the time of day alone would\n",
    "    # pick the wrong file as soon as checkpoints from different days coexist.\n",
    "    latest = os.path.join(ckpt_dir, max(ckpts, key=_ckpt_timestamp))\n",
    "\n",
    "    # map_location lets a checkpoint saved on GPU load on whatever `device` is in use.\n",
    "    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.\n",
    "    state = torch.load(latest, map_location=device)\n",
    "    stabnet.load_state_dict(state['model'])\n",
    "    print('loaded weights',latest)\n",
    "else:\n",
    "    print('no checkpoint found in',ckpt_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'\n",
    "cap = cv2.VideoCapture(path)\n",
    "frame_list = []\n",
    "try:\n",
    "    # Fail early with a clear message instead of producing an empty array further down.\n",
    "    if not cap.isOpened():\n",
    "        raise IOError(f'could not open video: {path}')\n",
    "    while True:\n",
    "        ret,frame = cap.read()\n",
    "        if not ret:\n",
    "            break\n",
    "        # cv2.resize takes the target size as (width, height).\n",
    "        frame_list.append(cv2.resize(frame,(width,height)))\n",
    "finally:\n",
    "    # Always give the file handle / decoder back, even on error or interrupt.\n",
    "    cap.release()\n",
    "if not frame_list:\n",
    "    raise ValueError(f'no frames decoded from: {path}')\n",
    "# (num_frames, height, width, 3) uint8 array\n",
    "frames = np.array(frame_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([447, 3, 360, 640])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (N,H,W,C) uint8 -> (N,C,H,W) float32 in [0,1].\n",
    "# Cast to float32 first and scale in place: `frames/255.0` would allocate a float64 copy\n",
    "# of the whole clip (8 bytes per value) before the cast.\n",
    "frames_t = torch.from_numpy(frames).permute(0,3,1,2).float().div_(255.0)\n",
    "frames_t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[9], line 14\u001b[0m\n\u001b[0;32m     12\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m     13\u001b[0m     trasnform \u001b[38;5;241m=\u001b[39m stabnet(net_in)\n\u001b[1;32m---> 14\u001b[0m     warped \u001b[38;5;241m=\u001b[39m get_warp(trasnform \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1\u001b[39m ,\u001b[43mcurr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m     15\u001b[0m     warped_frames[idx:idx\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m] \u001b[38;5;241m=\u001b[39m warped\u001b[38;5;241m.\u001b[39mcpu()\n\u001b[0;32m     16\u001b[0m     warped_gray \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmean(warped,dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m,keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "num_frames,_,h,w = frames_t.shape\n",
    "warped_frames = frames_t.clone()\n",
    "# Each frame is stabilised against already-processed frames 1, 2, 4, 8, 16 and 32 steps back.\n",
    "history_offsets = [2**i for i in range(6)]\n",
    "buffer = torch.zeros((len(history_offsets),1,h,w)).float()\n",
    "# Frames before this index are left untouched (start index kept from the original run).\n",
    "first_idx = 33\n",
    "processed = 0\n",
    "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
    "start = time()\n",
    "try:\n",
    "    with torch.no_grad():\n",
    "        for idx in range(first_idx,num_frames):\n",
    "            # Grayscale history, rebuilt from scratch for every frame.\n",
    "            for slot,offset in enumerate(history_offsets):\n",
    "                buffer[slot,...] = torch.mean(warped_frames[idx - offset,...],dim = 0,keepdim = True)\n",
    "            curr = warped_frames[idx:idx+1,...]\n",
    "            # 9-channel network input: current colour frame + 6 grayscale history frames.\n",
    "            net_in = torch.cat([curr,buffer.permute(1,0,2,3)], dim = 1).to(device)\n",
    "            transform = stabnet(net_in)\n",
    "            # Reuse the colour channels already on the device instead of uploading `curr` again.\n",
    "            warped = get_warp(transform,net_in[:,:3,...])\n",
    "            # Write back so later frames see the stabilised version as their history.\n",
    "            warped_frames[idx:idx+1,...] = warped.cpu()\n",
    "            processed += 1\n",
    "            img = warped_frames[idx,...].permute(1,2,0).numpy()\n",
    "            img = (img * 255).astype(np.uint8)\n",
    "            cv2.imshow('window',img)\n",
    "            # Space bar stops the run early.\n",
    "            if cv2.waitKey(1) & 0xFF == ord(' '):\n",
    "                break\n",
    "finally:\n",
    "    # Close the preview window even if the loop is interrupted.\n",
    "    cv2.destroyAllWindows()\n",
    "total = time() - start\n",
    "# Average over the frames that actually went through the network, not the whole clip.\n",
    "speed = total / max(processed,1)\n",
    "print(f'speed: {speed} seconds per frame')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from time import sleep\n",
    "out_path = './results/2.avi'\n",
    "# VideoWriter does not create missing folders.\n",
    "os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
    "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
    "# The writer's frame size is (width, height) and has to match the frames passed to\n",
    "# write(); take it from the data instead of hard-coding a size.\n",
    "frame_h,frame_w = warped_frames.shape[-2:]\n",
    "out = cv2.VideoWriter(out_path, fourcc, 30.0, (frame_w,frame_h))\n",
    "if not out.isOpened():\n",
    "    raise IOError(f'could not open video writer: {out_path}')\n",
    "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
    "try:\n",
    "    for idx in range(num_frames):\n",
    "        img = warped_frames[idx,...].permute(1,2,0).numpy()\n",
    "        # Round (instead of truncating) when going back to 8 bit, and hand OpenCV a\n",
    "        # contiguous (H,W,C) array.\n",
    "        img = np.ascontiguousarray(np.clip(np.rint(img * 255), 0, 255).astype(np.uint8))\n",
    "        out.write(img)\n",
    "        cv2.imshow('window',img)\n",
    "        sleep(1/30)\n",
    "        # Space bar stops the export early.\n",
    "        if cv2.waitKey(1) & 0xFF == ord(' '):\n",
    "            break\n",
    "finally:\n",
    "    # Finalise the file and close the window even if the loop is interrupted.\n",
    "    cv2.destroyAllWindows()\n",
    "    out.release()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the absolute difference between each stabilised frame and its source frame.\n",
    "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
    "try:\n",
    "    for idx in range(num_frames):\n",
    "        img = warped_frames[idx,...].permute(1,2,0).numpy()\n",
    "        # Round (instead of truncating) so untouched frames do not show spurious\n",
    "        # off-by-one differences against the uint8 source.\n",
    "        img = np.ascontiguousarray(np.clip(np.rint(img * 255), 0, 255).astype(np.uint8))\n",
    "        diff = cv2.absdiff(img,frames[idx,...])\n",
    "        cv2.imshow('window',diff)\n",
    "        sleep(1/30)\n",
    "        # Space bar stops the preview early.\n",
    "        if cv2.waitKey(1) & 0xFF == ord(' '):\n",
    "            break\n",
    "finally:\n",
    "    # Close the window even if the loop is interrupted.\n",
    "    cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Frame: 446/447\n",
      "cropping score:1.000\tdistortion score:0.989\tstability:0.639\tpixel:0.997\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(1.0, 0.98863894, 0.6389072784994596, 0.99749401723966)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `metric` is a project-local helper; per the recorded output of this cell it prints and\n",
    "# returns four scores (cropping, distortion, stability, pixel) for an original/result pair.\n",
    "from metrics import metric\n",
    "# NOTE(review): this evaluates './results/Regular_2.avi', while the export cell in this\n",
    "# notebook writes './results/2.avi' - confirm that this is the intended result file.\n",
    "metric('E:/Datasets/DeepStab_Dataset/unstable/2.avi','./results/Regular_2.avi')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DUTCode",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
